--- external/stylegan/training/dataset.py	2023-04-06 03:45:07.254378718 +0000
+++ external_reference/stylegan/training/dataset.py	2023-04-06 03:41:03.338611206 +0000
@@ -21,6 +21,9 @@
 except ImportError:
     pyspng = None
 
+import random
+from utils import midas, camera_util
+
 #----------------------------------------------------------------------------
 
 class Dataset(torch.utils.data.Dataset):
@@ -85,14 +88,60 @@
         return self._raw_idx.size
 
     def __getitem__(self, idx):
-        image = self._load_raw_image(self._raw_idx[idx])
+        """Return the `(image_info, label)` pair for sample `idx`.
+
+        `image_info` is the dict built by `_load_raw_image`, checked here:
+
+        * 'rgb' / 'orig': uint8 arrays of shape `self.image_shape`.
+        * 'depth' / 'acc': arrays matching `self.image_shape[1:]` spatially.
+        * 'K' / 'Rt': 3x3 and 4x4 arrays; they are not modified.
+
+        When `self._xflip[idx]` is set, the four image-like arrays are
+        mirrored along their last axis. They are always stored back as
+        copies. Any other entries of the dict are passed through untouched.
+        The label is `self.get_label(idx)`.
+        """
+        # pull every modality out of the freshly loaded sample
+        info = self._load_raw_image(self._raw_idx[idx])
+        image, depth, acc = info['rgb'], info['depth'], info['acc']
+        K, Rt, original = info['K'], info['Rt'], info['orig']
+
+        # masked image
         assert isinstance(image, np.ndarray)
         assert list(image.shape) == self.image_shape
         assert image.dtype == np.uint8
         if self._xflip[idx]:
             assert image.ndim == 3 # CHW
             image = image[:, :, ::-1]
-        return image.copy(), self.get_label(idx)
+        info['rgb'] = image.copy()
+        mirror = self._xflip[idx]
+
+        # original (unmasked) image: same shape and dtype checks
+        assert isinstance(original, np.ndarray)
+        assert list(original.shape) == self.image_shape
+        assert original.dtype == np.uint8
+        if mirror:
+            assert original.ndim == 3 # CHW
+            original = original[:, :, ::-1]
+        info['orig'] = original.copy()
+
+        # depth and mask: only the spatial dimensions are compared
+        for key, plane in (('depth', depth), ('acc', acc)):
+            assert isinstance(plane, np.ndarray)
+            assert list(plane.shape)[1:] == self.image_shape[1:]
+            if mirror:
+                assert plane.ndim == 3 # CHW
+                plane = plane[:, :, ::-1]
+            info[key] = plane.copy()
+
+        # check intrinsics and extrinsics
+        assert isinstance(K, np.ndarray)
+        assert list(K.shape) == [3, 3]
+        assert isinstance(Rt, np.ndarray)
+        assert list(Rt.shape) == [4, 4]
+
+        # same dict object, now holding the checked copies
+        return info, self.get_label(idx)
 
     def get_label(self, idx):
         label = self._get_raw_labels()[self._raw_idx[idx]]
@@ -157,14 +206,45 @@
     def __init__(self,
         path,                   # Path to directory or zip.
         resolution      = None, # Ensure specific resolution, None = highest available.
+        pose_path       = None, # Path to training pose distribution (a torch.load-able file with 'Rts' and 'cameras')
+        depth_scale     = 16,   # scale factor for depth
+        depth_clip      = 20,   # clip all depths above this value
+        use_disp        = True, # use inverse depth if true
+        fov_mean        = 60,   # intrinsics mean FOV, in degrees
+        fov_std         = 0,    # intrinsics std FOV, in degrees
+        mask_downsample = 'antialias', # downsample mode for mask ('nearest', anything else = antialias, softer boundary)
         **super_kwargs,         # Additional arguments for the Dataset base class.
     ):
+        """Index the files under `path` and record the depth/sky-mask locations and camera sampling options."""
         self._path = path
         self._zipfile = None
+        self._depth_path = path.replace('img', 'dpt_depth') # NOTE(review): rewrites every 'img' in path, not only the last component
+        self._seg_path = path.replace('img', 'dpt_sky') # NOTE(review): same caveat as _depth_path
+        self.depth_scale = depth_scale
+        self.depth_clip = depth_clip
+        self.use_disp = use_disp
+        self.pose_path = pose_path
+        if self.pose_path is not None:
+            data = torch.load(self.pose_path) # NOTE: torch.load unpickles the file; only point this at trusted pose files
+            self.Rt = data['Rts'].float().numpy()
+            self.cameras = data['cameras']
+        self.fov_mean = fov_mean
+        self.fov_std = fov_std
+        self.mask_downsample = mask_downsample
 
         if os.path.isdir(self._path):
             self._type = 'dir'
-            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
+            if os.path.isfile(self._path + '_cache.txt'): # file list cached next to the dataset directory
+                with open(self._path + '_cache.txt') as cache:
+                    self._all_fnames = {line.strip() for line in cache}
+            else:
+                print("Walking dataset...")
+                fnames = [os.path.relpath(os.path.join(root, fname), start=self._path)
+                          for root, _dirs, files in os.walk(self._path, followlinks=True) for fname in files]
+                with open(self._path + '_cache.txt', 'w') as cache:
+                    cache.writelines("%s\n" % fname for fname in fnames) # one relative path per line
+                self._all_fnames = set(fnames)
+                print("done walking")
         elif self._file_ext(self._path) == '.zip':
             self._type = 'zip'
             self._all_fnames = set(self._get_zipfile().namelist())
@@ -177,9 +257,14 @@
             raise IOError('No image files found in the specified path')
 
         name = os.path.splitext(os.path.basename(self._path))[0]
-        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
-        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
-            raise IOError('Image files do not match the specified resolution')
+        if resolution is not None:
+            # shape is fixed by `resolution`; _load_raw_image resizes each sample to the dataset shape
+            raw_shape = [len(self._image_fnames), 3, resolution, resolution]
+        else:
+            # _load_raw_image returns a dict (indexing it with [0] raised KeyError); read the native size directly
+            with self._open_file(self._image_fnames[0]) as f:
+                width, height = PIL.Image.open(f).size
+            raw_shape = [len(self._image_fnames), 3, height, width] # 3 channels: images are loaded as RGB
         super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
 
     @staticmethod
@@ -209,17 +294,111 @@
     def __getstate__(self):
         return dict(super().__getstate__(), _zipfile=None)
 
-    def _load_raw_image(self, raw_idx):
+    def _load_raw_image(self, raw_idx, resize=True):
+        """Load one sample as a dict of sky-masked RGB, depth, mask and camera.
+
+        Reads image `raw_idx`, its `.pfm` disparity map under `self._depth_path`
+        and its `sky_mask` `.npz` under `self._seg_path`. With `resize` set, all
+        maps are resized to `self.image_shape[1:]` (inputs must be square).
+        Returns a dict with keys 'rgb', 'depth', 'acc', 'K', 'Rt',
+        'global_size', 'fov' and 'orig'; image-like entries are CHW arrays.
+        """
         fname = self._image_fnames[raw_idx]
+        # swap only the extension (str.replace would also rewrite 'png' inside directory names)
+        stem = os.path.splitext(fname)[0]
+        ### load image
         with self._open_file(fname) as f:
-            if pyspng is not None and self._file_ext(fname) == '.png':
-                image = pyspng.load(f.read())
-            else:
-                image = np.array(PIL.Image.open(f))
+            image = PIL.Image.open(f).convert('RGB')
+        ### load depth map
+        disp, _scale = midas.read_pfm(os.path.join(self._depth_path, stem + '.pfm'))
+        # normalize 0 to 1 between the 1st and 99th percentiles
+        disp = np.array(disp)
+        dmmin = np.percentile(disp, 1)
+        dmmax = np.percentile(disp, 99)
+        scaled_disp = (disp-dmmin) / (dmmax-dmmin + 1e-6)
+        scaled_disp = np.clip(scaled_disp, 0., 1.) * 255
+        disp_img = PIL.Image.fromarray(scaled_disp.astype(np.uint8))
+        ### load sky mask (the archive is closed once the array has been read)
+        with np.load(os.path.join(self._seg_path, stem + '.npz')) as sky_npz:
+            sky_mask = sky_npz['sky_mask']
+        sky_img = PIL.Image.fromarray(sky_mask * 255)
+        ### remove sky from full size image (prevent sky color from leaking when downsampled)
+        image_np = np.array(image)
+        gray = np.array([128, 128, 128]).reshape(1, 1, -1)
+        image_gray = (image_np * sky_mask[..., None] + gray * (1-sky_mask[..., None])).astype(np.uint8)
+        image_gray = PIL.Image.fromarray(image_gray)
+
+        if resize:
+            # note: the input images should be square
+            assert(image.size[0] == image.size[1])
+            assert(self.image_shape[1] == self.image_shape[2])
+            target_size = tuple(self.image_shape[1:]) # tuple: PIL sizes never compare equal to a list
+            if image.size != target_size:
+                # LANCZOS is the filter ANTIALIAS used to alias; ANTIALIAS was removed in Pillow 10
+                mask_filter = PIL.Image.NEAREST if self.mask_downsample=='nearest' else PIL.Image.LANCZOS
+                image = image.resize(target_size, PIL.Image.LANCZOS)
+                image_gray = image_gray.resize(target_size, PIL.Image.LANCZOS)
+                disp_img = disp_img.resize(target_size, PIL.Image.LANCZOS)
+                sky_img = sky_img.resize(target_size, mask_filter)
+
+        # handle image dimensions
+        image = np.array(image)
         if image.ndim == 2:
             image = image[:, :, np.newaxis] # HW => HWC
         image = image.transpose(2, 0, 1) # HWC => CHW
-        return image
+        # handle image with sky mask dimensions
+        image_gray = np.array(image_gray)
+        if image_gray.ndim == 2:
+            image_gray = image_gray[:, :, np.newaxis] # HW => HWC
+        image_gray = image_gray.transpose(2, 0, 1) # HWC => CHW
+        # handle disp dimensions
+        disp = np.array(disp_img)
+        if disp.ndim == 2:
+            disp = disp[:, :, np.newaxis] # HW => HWC
+        disp = disp.transpose(2, 0, 1) # HWC => CHW
+        # handle sky mask dimensions
+        mask = np.array(sky_img)
+        if mask.ndim == 2:
+            mask = mask[:, :, np.newaxis] # HW => HWC
+        mask = mask.transpose(2, 0, 1) # HWC => CHW
+
+        mask = mask / 255 # process mask: rescale from [0, 255] to [0, 1]
+        # process disparity map (clip and rescale, to match nerf far)
+        disp = disp / 255 # convert back to [0, 1] range
+        disp_clipped = np.clip(disp, 1/self.depth_clip, 1) # range: [1/clip, 1]
+        pseudo_depth = 1/disp_clipped - 1 # range:[0, clip-1]
+        max_depth = self.depth_clip - 1
+        scaled_depth = pseudo_depth / max_depth * (self.depth_scale - 1) # range: [0, depth_scale-1]
+        scaled_disp = 1/(scaled_depth+1) # range: [1/depth_scale, 1]
+
+        # multiply everything by the downsampled mask
+        scaled_disp = scaled_disp * mask
+        scaled_depth = scaled_depth * mask
+        gray = np.array([128, 128, 128]).reshape(-1, 1, 1)
+        rgb_masked = (image_gray * mask + gray * (1-mask)).astype(np.uint8)
+
+        # intrinsics: focal lengths from a FOV (degrees) drawn around fov_mean; diagonal is (fx, -fy, -1)
+        fov = self.fov_mean + self.fov_std * np.random.randn()
+        fx = (self.image_shape[2] / 2) / np.tan(np.deg2rad(fov) / 2)
+        fy = (self.image_shape[1] / 2) / np.tan(np.deg2rad(fov) / 2)
+        K = np.diag([fx, -fy, -1.0])
+
+        # extrinsics: a random pose from the loaded distribution, identity when none was given
+        if self.pose_path is not None:
+            idx = random.randint(0, self.Rt.shape[0]-1)
+            Rt = self.Rt[idx].astype(np.float64)[0]
+        else:
+            Rt = np.eye(4)
+
+        return {'rgb': rgb_masked,
+                'depth': scaled_disp if self.use_disp else scaled_depth,
+                'acc': mask,
+                'K': K, # 3x3
+                'Rt': Rt, #4x4
+                'global_size': self.image_shape[-1],
+                'fov': fov,
+                'orig': image,
+               }
 
     def _load_raw_labels(self):
         fname = 'dataset.json'
